Breast Cancer Diagnosis Prediction Analysis
1. Project Overview
2. Data Processing
2.1 Data Loading and Preview
- Load the CSV dataset
- Display data samples in an interactive table
- Basic descriptive statistics
2.2 Data Preprocessing
- Drop irrelevant columns (ID) and missing values
- Encode labels numerically (M→1, B→0)
- Stratified train/test split
- SMOTE oversampling to handle class imbalance
- Feature standardization
2.3 Data Visualization
- Pie chart of the diagnosis distribution
- Box plots of feature distributions (shown in groups)
- Scatter matrix of key features
- Feature correlation heatmap
- RFE feature selection
2.4 Feature Engineering
- Outlier detection (Z-score + IQR + Isolation Forest)
- Low-variance feature filtering
- Removal of highly correlated features
- RFE feature selection and visualization
2.5 Model Building and Evaluation
- Baseline model: logistic regression + SMOTE
- Optimized model: XGBoost + RandomizedSearchCV + early stopping + threshold optimization
- Performance comparison: accuracy, recall, F1, etc.
- ROC and PR curve analysis
3. Key Conclusions
- Best-model performance comparison
- Feature importance analysis
- Dimensionality-reduction evaluation
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.metrics import precision_recall_curve
import plotly.graph_objects as go
from plotly.offline import init_notebook_mode
init_notebook_mode(connected=False)
# Configure matplotlib to render CJK glyphs (SimHei) and minus signs correctly
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False
import plotly.io as pio
pio.renderers.default = "plotly_mimetype+notebook"
df = pd.read_csv('data.csv')
pd.set_option('display.max_columns', None)
pd.set_option('display.width', None)
pd.set_option('display.max_colwidth', None)
from IPython.display import HTML
styled_html = df.head(10).style \
.set_properties(**{'text-align': 'center'}) \
.set_table_styles([
{'selector': 'th', 'props': [('background-color', '#404040'), ('color', 'white')]},
{'selector': 'td', 'props': [('border', '1px solid #dee2e6')]}
]) \
.format({
'radius_mean': '{:.2f}', 'texture_mean': '{:.2f}', 'area_mean': '{:.2f}',
'smoothness_mean': '{:.4f}', 'compactness_mean': '{:.4f}', 'concavity_mean': '{:.4f}',
'concave points_mean': '{:.4f}', 'symmetry_mean': '{:.4f}', 'fractal_dimension_mean': '{:.4f}',
}) \
.highlight_max(color='lightgreen') \
.highlight_min(color='salmon') \
.to_html()
html_with_scroll = f"""
<div style='overflow-x: auto; max-width: 100%;'>
{styled_html}
</div>
"""
display(HTML(html_with_scroll))
print("\nę°ę®åŗę¬ē»č®”äæ”ęÆļ¼")
styled_stats = df.describe().style.format('{:.2f}')
stats_html = styled_stats.to_html()
stats_with_scroll = f"""
<div style='overflow-x: auto; max-width: 100%;'>
{stats_html}
</div>
"""
display(HTML(stats_with_scroll))
|   | id | diagnosis | radius_mean | texture_mean | perimeter_mean | area_mean | smoothness_mean | compactness_mean | concavity_mean | concave points_mean | symmetry_mean | fractal_dimension_mean | radius_se | texture_se | perimeter_se | area_se | smoothness_se | compactness_se | concavity_se | concave points_se | symmetry_se | fractal_dimension_se | radius_worst | texture_worst | perimeter_worst | area_worst | smoothness_worst | compactness_worst | concavity_worst | concave points_worst | symmetry_worst | fractal_dimension_worst | Unnamed: 32 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 842302 | M | 17.99 | 10.38 | 122.800000 | 1001.00 | 0.1184 | 0.2776 | 0.3001 | 0.1471 | 0.2419 | 0.0787 | 1.095000 | 0.905300 | 8.589000 | 153.400000 | 0.006399 | 0.049040 | 0.053730 | 0.015870 | 0.030030 | 0.006193 | 25.380000 | 17.330000 | 184.600000 | 2019.000000 | 0.162200 | 0.665600 | 0.711900 | 0.265400 | 0.460100 | 0.118900 | nan |
| 1 | 842517 | M | 20.57 | 17.77 | 132.900000 | 1326.00 | 0.0847 | 0.0786 | 0.0869 | 0.0702 | 0.1812 | 0.0567 | 0.543500 | 0.733900 | 3.398000 | 74.080000 | 0.005225 | 0.013080 | 0.018600 | 0.013400 | 0.013890 | 0.003532 | 24.990000 | 23.410000 | 158.800000 | 1956.000000 | 0.123800 | 0.186600 | 0.241600 | 0.186000 | 0.275000 | 0.089020 | nan |
| 2 | 84300903 | M | 19.69 | 21.25 | 130.000000 | 1203.00 | 0.1096 | 0.1599 | 0.1974 | 0.1279 | 0.2069 | 0.0600 | 0.745600 | 0.786900 | 4.585000 | 94.030000 | 0.006150 | 0.040060 | 0.038320 | 0.020580 | 0.022500 | 0.004571 | 23.570000 | 25.530000 | 152.500000 | 1709.000000 | 0.144400 | 0.424500 | 0.450400 | 0.243000 | 0.361300 | 0.087580 | nan |
| 3 | 84348301 | M | 11.42 | 20.38 | 77.580000 | 386.10 | 0.1425 | 0.2839 | 0.2414 | 0.1052 | 0.2597 | 0.0974 | 0.495600 | 1.156000 | 3.445000 | 27.230000 | 0.009110 | 0.074580 | 0.056610 | 0.018670 | 0.059630 | 0.009208 | 14.910000 | 26.500000 | 98.870000 | 567.700000 | 0.209800 | 0.866300 | 0.686900 | 0.257500 | 0.663800 | 0.173000 | nan |
| 4 | 84358402 | M | 20.29 | 14.34 | 135.100000 | 1297.00 | 0.1003 | 0.1328 | 0.1980 | 0.1043 | 0.1809 | 0.0588 | 0.757200 | 0.781300 | 5.438000 | 94.440000 | 0.011490 | 0.024610 | 0.056880 | 0.018850 | 0.017560 | 0.005115 | 22.540000 | 16.670000 | 152.200000 | 1575.000000 | 0.137400 | 0.205000 | 0.400000 | 0.162500 | 0.236400 | 0.076780 | nan |
| 5 | 843786 | M | 12.45 | 15.70 | 82.570000 | 477.10 | 0.1278 | 0.1700 | 0.1578 | 0.0809 | 0.2087 | 0.0761 | 0.334500 | 0.890200 | 2.217000 | 27.190000 | 0.007510 | 0.033450 | 0.036720 | 0.011370 | 0.021650 | 0.005082 | 15.470000 | 23.750000 | 103.400000 | 741.600000 | 0.179100 | 0.524900 | 0.535500 | 0.174100 | 0.398500 | 0.124400 | nan |
| 6 | 844359 | M | 18.25 | 19.98 | 119.600000 | 1040.00 | 0.0946 | 0.1090 | 0.1127 | 0.0740 | 0.1794 | 0.0574 | 0.446700 | 0.773200 | 3.180000 | 53.910000 | 0.004314 | 0.013820 | 0.022540 | 0.010390 | 0.013690 | 0.002179 | 22.880000 | 27.660000 | 153.200000 | 1606.000000 | 0.144200 | 0.257600 | 0.378400 | 0.193200 | 0.306300 | 0.083680 | nan |
| 7 | 84458202 | M | 13.71 | 20.83 | 90.200000 | 577.90 | 0.1189 | 0.1645 | 0.0937 | 0.0599 | 0.2196 | 0.0745 | 0.583500 | 1.377000 | 3.856000 | 50.960000 | 0.008805 | 0.030290 | 0.024880 | 0.014480 | 0.014860 | 0.005412 | 17.060000 | 28.140000 | 110.600000 | 897.000000 | 0.165400 | 0.368200 | 0.267800 | 0.155600 | 0.319600 | 0.115100 | nan |
| 8 | 844981 | M | 13.00 | 21.82 | 87.500000 | 519.80 | 0.1273 | 0.1932 | 0.1859 | 0.0935 | 0.2350 | 0.0739 | 0.306300 | 1.002000 | 2.406000 | 24.320000 | 0.005731 | 0.035020 | 0.035530 | 0.012260 | 0.021430 | 0.003749 | 15.490000 | 30.730000 | 106.200000 | 739.300000 | 0.170300 | 0.540100 | 0.539000 | 0.206000 | 0.437800 | 0.107200 | nan |
| 9 | 84501001 | M | 12.46 | 24.04 | 83.970000 | 475.90 | 0.1186 | 0.2396 | 0.2273 | 0.0854 | 0.2030 | 0.0824 | 0.297600 | 1.599000 | 2.039000 | 23.940000 | 0.007149 | 0.072170 | 0.077430 | 0.014320 | 0.017890 | 0.010080 | 15.090000 | 40.680000 | 97.650000 | 711.400000 | 0.185300 | 1.058000 | 1.105000 | 0.221000 | 0.436600 | 0.207500 | nan |
Basic descriptive statistics:
|   | id | radius_mean | texture_mean | perimeter_mean | area_mean | smoothness_mean | compactness_mean | concavity_mean | concave points_mean | symmetry_mean | fractal_dimension_mean | radius_se | texture_se | perimeter_se | area_se | smoothness_se | compactness_se | concavity_se | concave points_se | symmetry_se | fractal_dimension_se | radius_worst | texture_worst | perimeter_worst | area_worst | smoothness_worst | compactness_worst | concavity_worst | concave points_worst | symmetry_worst | fractal_dimension_worst | Unnamed: 32 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| count | 569.00 | 569.00 | 569.00 | 569.00 | 569.00 | 569.00 | 569.00 | 569.00 | 569.00 | 569.00 | 569.00 | 569.00 | 569.00 | 569.00 | 569.00 | 569.00 | 569.00 | 569.00 | 569.00 | 569.00 | 569.00 | 569.00 | 569.00 | 569.00 | 569.00 | 569.00 | 569.00 | 569.00 | 569.00 | 569.00 | 569.00 | 0.00 |
| mean | 30371831.43 | 14.13 | 19.29 | 91.97 | 654.89 | 0.10 | 0.10 | 0.09 | 0.05 | 0.18 | 0.06 | 0.41 | 1.22 | 2.87 | 40.34 | 0.01 | 0.03 | 0.03 | 0.01 | 0.02 | 0.00 | 16.27 | 25.68 | 107.26 | 880.58 | 0.13 | 0.25 | 0.27 | 0.11 | 0.29 | 0.08 | nan |
| std | 125020585.61 | 3.52 | 4.30 | 24.30 | 351.91 | 0.01 | 0.05 | 0.08 | 0.04 | 0.03 | 0.01 | 0.28 | 0.55 | 2.02 | 45.49 | 0.00 | 0.02 | 0.03 | 0.01 | 0.01 | 0.00 | 4.83 | 6.15 | 33.60 | 569.36 | 0.02 | 0.16 | 0.21 | 0.07 | 0.06 | 0.02 | nan |
| min | 8670.00 | 6.98 | 9.71 | 43.79 | 143.50 | 0.05 | 0.02 | 0.00 | 0.00 | 0.11 | 0.05 | 0.11 | 0.36 | 0.76 | 6.80 | 0.00 | 0.00 | 0.00 | 0.00 | 0.01 | 0.00 | 7.93 | 12.02 | 50.41 | 185.20 | 0.07 | 0.03 | 0.00 | 0.00 | 0.16 | 0.06 | nan |
| 25% | 869218.00 | 11.70 | 16.17 | 75.17 | 420.30 | 0.09 | 0.06 | 0.03 | 0.02 | 0.16 | 0.06 | 0.23 | 0.83 | 1.61 | 17.85 | 0.01 | 0.01 | 0.02 | 0.01 | 0.02 | 0.00 | 13.01 | 21.08 | 84.11 | 515.30 | 0.12 | 0.15 | 0.11 | 0.06 | 0.25 | 0.07 | nan |
| 50% | 906024.00 | 13.37 | 18.84 | 86.24 | 551.10 | 0.10 | 0.09 | 0.06 | 0.03 | 0.18 | 0.06 | 0.32 | 1.11 | 2.29 | 24.53 | 0.01 | 0.02 | 0.03 | 0.01 | 0.02 | 0.00 | 14.97 | 25.41 | 97.66 | 686.50 | 0.13 | 0.21 | 0.23 | 0.10 | 0.28 | 0.08 | nan |
| 75% | 8813129.00 | 15.78 | 21.80 | 104.10 | 782.70 | 0.11 | 0.13 | 0.13 | 0.07 | 0.20 | 0.07 | 0.48 | 1.47 | 3.36 | 45.19 | 0.01 | 0.03 | 0.04 | 0.01 | 0.02 | 0.00 | 18.79 | 29.72 | 125.40 | 1084.00 | 0.15 | 0.34 | 0.38 | 0.16 | 0.32 | 0.09 | nan |
| max | 911320502.00 | 28.11 | 39.28 | 188.50 | 2501.00 | 0.16 | 0.35 | 0.43 | 0.20 | 0.30 | 0.10 | 2.87 | 4.88 | 21.98 | 542.20 | 0.03 | 0.14 | 0.40 | 0.05 | 0.08 | 0.03 | 36.04 | 49.54 | 251.20 | 4254.00 | 0.22 | 1.06 | 1.25 | 0.29 | 0.66 | 0.21 | nan |
2.2 Data Preprocessing
Data cleaning, label encoding, class-imbalance handling, and feature standardization
data = pd.read_csv('data.csv')
data = data.drop(['id', 'Unnamed: 32'], axis=1, errors='ignore')  # drop unused columns
data = data.dropna()  # drop rows containing missing values
data['diagnosis'] = data['diagnosis'].map({'M': 1, 'B': 0})  # encode labels numerically
X = data.drop('diagnosis', axis=1)
y = data['diagnosis']
print("Data cleaning done, samples:", data.shape[0], "features:", X.shape[1])
# Split into train and test sets (before SMOTE and standardization, to avoid leakage)
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42, stratify=y)
# Apply SMOTE to the training set only
from imblearn.over_sampling import SMOTE
smote = SMOTE(random_state=42)
X_train_smote, y_train_smote = smote.fit_resample(X_train, y_train)
# Standardize features (fit on the resampled training set, transform the test set)
from sklearn.preprocessing import StandardScaler
scaler = StandardScaler()
X_train_smote = scaler.fit_transform(X_train_smote)
X_test = scaler.transform(X_test)
print("Preprocessing done, training samples:", X_train_smote.shape[0], "test samples:", X_test.shape[0])
Data cleaning done, samples: 569 features: 30
Preprocessing done, training samples: 570 test samples: 114
2.3 Data Visualization
Create an interactive pie chart of the benign/malignant diagnosis distribution
import plotly.graph_objects as go
from IPython.display import HTML
import pandas as pd
# Example data (a toy 10-sample series; use data['diagnosis'] for the real distribution).
# A separate name avoids shadowing the real target variable y defined earlier.
y_demo = pd.Series([0, 1, 0, 0, 1, 1, 0, 1, 1, 1])  # 0 = benign, 1 = malignant
labels = ['Benign', 'Malignant']
values = y_demo.value_counts().sort_index()
fig = go.Figure(data=[go.Pie(
labels=labels,
values=values,
textinfo='label+percent+value',
textposition='inside',
insidetextorientation='radial',
marker=dict(colors=['#40E0D0', '#FFD700'], line=dict(color='#FFFFFF', width=2)),
hoverinfo='label+percent+value',
hole=0.2,
)])
fig.update_layout(
title_text='Diagnosis distribution (benign vs malignant)-é»äŗēæ',
title_font_size=20,
title_x=0.5,
legend_title_text='Diagnosis type',
legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="center", x=0.5),
showlegend=True,
paper_bgcolor='white',
plot_bgcolor='white',
font=dict(color='black')
)
html_content = fig.to_html(full_html=False, include_plotlyjs='cdn')
display(HTML(f'<div style="width:800px; margin:0 auto; background-color: white">{html_content}</div>'))
Group the features (5 per figure) and create box plots of the min-max scaled distributions to show spread and outliers
from sklearn.preprocessing import MinMaxScaler
import plotly.graph_objects as go
import pandas as pd
from IPython.display import HTML
scaler = MinMaxScaler()
X_scaled = scaler.fit_transform(X)
X_scaled_df = pd.DataFrame(X_scaled, columns=X.columns)
colors = ['indianred', 'mediumseagreen', 'dodgerblue', 'plum', 'darkkhaki',
'lightsalmon', 'gold', 'mediumturquoise', 'darkorange', 'lightgreen']
all_colors = colors * (len(X.columns) // len(colors) + 1)
# Collect the generated HTML for every figure in a list
html_contents = []
for group_idx in range(0, len(X_scaled_df.columns), 5):
group_cols = X_scaled_df.columns[group_idx:group_idx + 5]
fig = go.Figure()
for i, (col, color) in enumerate(zip(group_cols, all_colors[:len(group_cols)])):
fig.add_trace(go.Box(
y=X_scaled_df[col],
name=col,
boxpoints='outliers',
jitter=0.5,
pointpos=0,
whiskerwidth=0.2,
fillcolor=color,
marker=dict(size=3, color=color, line=dict(width=1, color='black')),
line=dict(width=2, color='black'),
opacity=0.8,
hovertemplate=f'Feature: {col}<br>Scaled value: %{{y}}<extra></extra>'
))
fig.update_layout(
title=f'Breast cancer feature distributions (min-max scaled) - group {group_idx // 5 + 1}',
title_font_size=16,
yaxis_title='Scaled value (0-1)',
xaxis=dict(tickangle=45, tickfont=dict(size=10, color='black')),
showlegend=True,
height=500,
width=1200,
margin=dict(l=40, r=40, t=50, b=80),
plot_bgcolor='white',
paper_bgcolor='white',
font=dict(color='black'),
boxmode='group',
boxgroupgap=0.05,
boxgap=0.2
)
# Convert the figure to HTML and store it
html_content = fig.to_html(full_html=False, include_plotlyjs='cdn')
html_contents.append(html_content)
print(f"å·²ēęļ¼ä¹³č
ŗēē¹å¾ååøē®±åå¾ļ¼ē¬¬ {group_idx // 5 + 1} ē»ļ¼{', '.join(group_cols)}ļ¼ć")
# Display all interactive charts
for i, html_content in enumerate(html_contents):
display(HTML(f'<h3>Breast cancer feature distributions (min-max scaled)-é»äŗēæ - group {i+1}</h3>'))
display(HTML(f'<div style="width:1200px; margin:0 auto">{html_content}</div>'))
Generated box plot group 1: radius_mean, texture_mean, perimeter_mean, area_mean, smoothness_mean.
Generated box plot group 2: compactness_mean, concavity_mean, concave points_mean, symmetry_mean, fractal_dimension_mean.
Generated box plot group 3: radius_se, texture_se, perimeter_se, area_se, smoothness_se.
Generated box plot group 4: compactness_se, concavity_se, concave points_se, symmetry_se, fractal_dimension_se.
Generated box plot group 5: radius_worst, texture_worst, perimeter_worst, area_worst, smoothness_worst.
Generated box plot group 6: compactness_worst, concavity_worst, concave points_worst, symmetry_worst, fractal_dimension_worst.
Breast cancer feature distributions (min-max scaled)-é»äŗēæ - group 1
Breast cancer feature distributions (min-max scaled)-é»äŗēæ - group 2
Breast cancer feature distributions (min-max scaled)-é»äŗēæ - group 3
Breast cancer feature distributions (min-max scaled)-é»äŗēæ - group 4
Breast cancer feature distributions (min-max scaled)-é»äŗēæ - group 5
Breast cancer feature distributions (min-max scaled)-é»äŗēæ - group 6
Select the 4 features most correlated with diagnosis
Create a scatter matrix to show pairwise feature relationships
Create histograms to show marginal feature distributions
import plotly.express as px
import numpy as np
import seaborn as sns
from IPython.display import HTML
# 1. Scatter matrix: pick the 4 features most correlated with diagnosis
correlation = data.corr()['diagnosis'].abs().sort_values(ascending=False)[1:5]
key_features = correlation.index.tolist()
fig_scatter = px.scatter_matrix(
data,
dimensions=key_features,
color='diagnosis',
color_continuous_scale=['#FF9999', '#99FF99'],
title='Key feature joint and marginal distributions-é»äŗēæ',
labels={col: col.replace('_', ' ').title() for col in key_features},
height=800,
width=1000
)
fig_scatter.update_traces(
diagonal_visible=True,
showupperhalf=False,
marker=dict(size=6, opacity=0.6)
)
fig_scatter.update_layout(
title_font_size=20,
title_x=0.5,
plot_bgcolor='white',
paper_bgcolor='white',
font=dict(size=10, color='black'),
legend_title_text='Diagnosis',
legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="center", x=0.5)
)
# 2. Histograms with marginal box plots
fig_hist = px.histogram(
data,
x=key_features,
color='diagnosis',
marginal="box",
opacity=0.7,
barmode='overlay',
height=400,
width=1000,
title='Key feature marginal distributions with box plots-é»äŗēæ'
)
fig_hist.update_layout(
plot_bgcolor='white',
paper_bgcolor='white',
font=dict(size=10, color='black'),
legend_title_text='Diagnosis',
legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="center", x=0.5)
)
scatter_html = fig_scatter.to_html(full_html=False, include_plotlyjs='cdn')
hist_html = fig_hist.to_html(full_html=False, include_plotlyjs='cdn')
display(HTML("""
<style>
.plot-container {
margin: 20px auto;
border: 1px solid #eee;
box-shadow: 0 0 10px rgba(0,0,0,0.1);
padding: 15px;
}
</style>
"""))
display(HTML('<h2 style="text-align:center">Interactive visualization analysis</h2>'))
display(HTML('<div class="plot-container">' + scatter_html + '</div>'))
display(HTML('<div class="plot-container">' + hist_html + '</div>'))
fig_scatter.write_html("scatter_matrix.html", full_html=True)
fig_hist.write_html("histogram_boxplot.html", full_html=True)
Interactive visualization analysis
Create a correlation heatmap of all features to identify highly correlated pairs
corr_matrix = X.corr()
plt.figure(figsize=(12, 10))
sns.heatmap(corr_matrix, annot=False, cmap='coolwarm')
plt.title('Feature correlation heatmap-é»äŗēæ')
plt.show()
Create violin plots comparing key feature distributions between benign and malignant diagnoses
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.preprocessing import StandardScaler
# Key feature selection
key_features = ['radius_mean', 'texture_mean', 'area_mean', 'concavity_mean']
feature_names = {
'radius_mean': 'Mean radius',
'texture_mean': 'Mean texture',
'area_mean': 'Mean area',
'concavity_mean': 'Mean concavity',
'diagnosis': 'Diagnosis'
}
scaler = StandardScaler()
data_scaled = data.copy()
data_scaled[key_features] = scaler.fit_transform(data[key_features])
data_melted = pd.melt(data_scaled, id_vars='diagnosis', value_vars=key_features, var_name='variable', value_name='value')
plt.figure(figsize=(12, 6))
sns.violinplot(x='variable', y='value', hue='diagnosis', data=data_melted, split=True, inner='quart')
plt.title('Violin plots of standardized key features-é»äŗēæ')
plt.xticks(rotation=45)
current_labels = [label.get_text() for label in plt.gca().get_xticklabels()]
plt.xticks(range(len(current_labels)), [feature_names[label] for label in current_labels])
handles, labels = plt.gca().get_legend_handles_labels()
plt.legend(handles, ['Benign', 'Malignant'], title='Diagnosis')
plt.xlabel('Feature')
plt.ylabel('Standardized value')
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False
plt.tight_layout()
plt.savefig('violin_plot_key_features_chinese.png', dpi=300, bbox_inches='tight')
plt.show()
StandardScaler is applied to the feature matrix X to remove the effect of differing feature scales, and the data is split 80/20 into training and test sets. In the RFE stage, a logistic regression model serves as the base classifier: RFE recursively eliminates the least important features, and after fitting, the ranking_ attribute gives each feature's rank (rank 1 = most important). To evaluate performance against feature count, every k from 1 to the full feature count is tried: the top-k ranked features are selected, training-set performance is estimated with 5-fold cross-validation, and test-set accuracy is recorded as well. Finally, the feature counts with the best cross-validation and test performance are identified, and the coefficients of the model fitted on the best feature subset are used as feature importances.
from sklearn.model_selection import train_test_split, cross_val_score
from sklearn.preprocessing import StandardScaler
from sklearn.feature_selection import RFE
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import numpy as np
from IPython.display import HTML
scaler = StandardScaler()
X = data.drop('diagnosis', axis=1)
y = data['diagnosis']
X_scaled = scaler.fit_transform(X)
X_train, X_test, y_train, y_test = train_test_split(X_scaled, y, test_size=0.2, random_state=42)
# Initialize the RFE selector and base model
model = LogisticRegression(random_state=42, max_iter=1000)
rfe = RFE(estimator=model, n_features_to_select=1, step=1)
rfe.fit(X_train, y_train)
# Get the feature ranking and support mask
feature_ranking = rfe.ranking_
feature_support = rfe.support_
# Evaluate performance at each feature count
n_features = X_train.shape[1]
n_features_range = range(1, n_features + 1)
cv_scores = []
test_scores = []
for k in n_features_range:
selected_features = feature_ranking <= k
X_train_selected = X_train[:, selected_features]
X_test_selected = X_test[:, selected_features]
model = LogisticRegression(random_state=42, max_iter=1000)
cv_score = cross_val_score(model, X_train_selected, y_train, cv=5, scoring='accuracy').mean()
model.fit(X_train_selected, y_train)
test_score = accuracy_score(y_test, model.predict(X_test_selected))
cv_scores.append(cv_score)
test_scores.append(test_score)
# Find the best feature counts
best_cv_score_idx = np.argmax(cv_scores)
best_test_score_idx = np.argmax(test_scores)
best_cv_k = n_features_range[best_cv_score_idx]
best_test_k = n_features_range[best_test_score_idx]
# Get the coefficients for the best feature subset
best_features = feature_ranking <= best_cv_k
model.fit(X_train[:, best_features], y_train)
feature_importances = np.abs(model.coef_[0])
sorted_idx = np.argsort(feature_importances)[::-1][:15]
# Build the visualization
fig = make_subplots(
rows=3, cols=1,
subplot_titles=(
'Feature ranking-é»äŗēæ (1 = most important)',
f'Weights of the best feature subset-é»äŗēæ (k={best_cv_k})',
'Model performance vs. number of features-é»äŗēæ'
),
vertical_spacing=0.15,
horizontal_spacing=0.05
)
# Panel 1: feature ranking
top_features = np.argsort(feature_ranking)[:15]
fig.add_trace(
go.Bar(
x=[f'Feature {i}' for i in top_features],
y=feature_ranking[top_features],
marker_color='royalblue',
name='Feature ranking',
hovertemplate='Feature: %{x}<br>Rank: %{y}<extra></extra>'
),
row=1, col=1
)
# Panel 2: feature importance
fig.add_trace(
go.Bar(
x=[f'Feature {i}' for i in sorted_idx],
y=feature_importances[sorted_idx],
marker_color='mediumseagreen',
name='Feature weight',
hovertemplate='Feature: %{x}<br>Weight: %{y:.4f}<extra></extra>'
),
row=2, col=1
)
# Panel 3: model performance
fig.add_trace(
go.Scatter(
x=list(n_features_range),
y=cv_scores,
mode='lines+markers',
line=dict(color='mediumorchid', width=2),
marker=dict(size=8),
name='Cross-validation score',
hovertemplate='Features: %{x}<br>CV score: %{y:.3f}<extra></extra>'
),
row=3, col=1
)
fig.add_trace(
go.Scatter(
x=list(n_features_range),
y=test_scores,
mode='lines+markers',
line=dict(color='deepskyblue', width=2),
marker=dict(size=8),
name='Test score',
hovertemplate='Features: %{x}<br>Test score: %{y:.3f}<extra></extra>'
),
row=3, col=1
)
# Mark the best feature counts
fig.add_vline(
x=best_cv_k,
line=dict(dash='dash', color='red', width=1.5),
annotation_text=f'Best CV: k={best_cv_k}',
annotation_position='top right',
row=3, col=1
)
fig.add_vline(
x=best_test_k,
line=dict(dash='dash', color='green', width=1.5),
annotation_text=f'Best test: k={best_test_k}',
annotation_position='bottom right',
row=3, col=1
)
# Update the overall layout
fig.update_layout(
height=1200,
width=1000,
title_text='RFE feature selection analysis',
title_font=dict(size=24, family="Arial", color='black'),
title_x=0.5,
showlegend=True,
legend=dict(
orientation="h",
yanchor="bottom",
y=1.02,
xanchor="center",
x=0.5
),
plot_bgcolor='white',
paper_bgcolor='white',
font=dict(size=12, color='black'),
margin=dict(l=80, r=80, t=100, b=80)
)
# Update per-panel axis labels
fig.update_xaxes(title_text='Feature index', row=1, col=1, title_font=dict(size=14))
fig.update_yaxes(title_text='Rank (1 = most important)', row=1, col=1, title_font=dict(size=14))
fig.update_xaxes(title_text='Feature index', row=2, col=1, title_font=dict(size=14))
fig.update_yaxes(title_text='Absolute weight', row=2, col=1, title_font=dict(size=14))
fig.update_xaxes(title_text='Number of features', row=3, col=1, title_font=dict(size=14))
fig.update_yaxes(title_text='Accuracy', row=3, col=1, title_font=dict(size=14))
# Generate the interactive HTML
html_content = fig.to_html(
full_html=False,
include_plotlyjs='cdn',
config={
'responsive': True,
'displayModeBar': True,
'scrollZoom': True
}
)
# Build a centered HTML container
centered_html = f"""
<div style="
display: flex;
justify-content: center;
align-items: center;
flex-direction: column;
width: 100%;
padding: 20px;
background-color: #f9f9f9;
border-radius: 10px;
box-shadow: 0 4px 8px rgba(0,0,0,0.1);
">
<h2 style="text-align: center; color: #333; margin-bottom: 20px;">
RFEē¹å¾éę©äŗ¤äŗå¼åę
</h2>
<div style="width: 1000px; height: 1200px;">
{html_content}
</div>
</div>
"""
display(HTML(centered_html))
Interactive RFE feature selection analysis
2.4 Feature Engineering
1. Comprehensive outlier detection and handling
Multi-method detection: combine Z-score (>3σ), IQR (1.5× the interquartile range), and Isolation Forest (10% contamination) to flag outliers
Effect evaluation: compare distribution shifts (mean/std) and logistic regression accuracy before and after removal
Visualization: grouped box plots of outliers in the raw data
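The three detectors above can be combined into one helper. This is a minimal sketch, not code from the notebook; the function name is illustrative and the thresholds (3σ, 1.5×IQR, 10% contamination) are taken from the bullet points:

```python
import numpy as np
import pandas as pd
from scipy import stats
from sklearn.ensemble import IsolationForest

def flag_outliers(X: pd.DataFrame) -> pd.DataFrame:
    """Flag each row by three detectors: Z-score, IQR, and Isolation Forest."""
    # Z-score: any feature more than 3 standard deviations from its column mean
    z_mask = (np.abs(stats.zscore(X.to_numpy())) > 3).any(axis=1)
    # IQR: any feature outside [Q1 - 1.5*IQR, Q3 + 1.5*IQR]
    q1, q3 = X.quantile(0.25), X.quantile(0.75)
    iqr = q3 - q1
    iqr_mask = ((X < q1 - 1.5 * iqr) | (X > q3 + 1.5 * iqr)).any(axis=1)
    # Isolation Forest with an assumed 10% contamination rate
    iso = IsolationForest(contamination=0.10, random_state=42)
    iso_mask = iso.fit_predict(X) == -1
    return pd.DataFrame({'zscore': z_mask, 'iqr': iqr_mask, 'iforest': iso_mask},
                        index=X.index)
```

Rows flagged by all three detectors are the safest candidates for removal; the before/after comparison in the bullets would then refit the baseline model on the filtered data.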
2. Feature filtering and optimization
Low-variance filter: remove features with variance < 0.01 (e.g., near-constant, uninformative ones)
High-correlation removal: drop redundant features with correlation > 0.9 to avoid multicollinearity
RFE feature selection: rank features by recursive elimination, evaluate model performance at each feature count (5-fold cross-validation), and automatically pick the optimal subset (best k = the feature count at the cross-validation peak)
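The variance and correlation filters can be sketched as below (the helper name is an assumption; the 0.01 and 0.9 thresholds come from the bullets above). For the correlated case only the second member of each offending pair is dropped, which keeps one representative of the pair:

```python
import numpy as np
import pandas as pd

def filter_features(X: pd.DataFrame, var_thresh=0.01, corr_thresh=0.9) -> pd.DataFrame:
    """Drop near-constant features, then one of each highly correlated pair."""
    # 1) Low-variance filter: variance below var_thresh carries little signal
    kept = X.loc[:, X.var() > var_thresh]
    # 2) Correlation filter: scan the upper triangle of |corr| and drop the
    #    later column of every pair above corr_thresh
    corr = kept.corr().abs()
    upper = corr.where(np.triu(np.ones(corr.shape, dtype=bool), k=1))
    to_drop = [c for c in upper.columns if (upper[c] > corr_thresh).any()]
    return kept.drop(columns=to_drop)
```

Note that the variance filter is scale-sensitive, so it should run on the raw (unstandardized) features as described above.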
3. Multi-view visualization
Feature ranking chart: bar chart of the top-15 features identified by RFE
Weight analysis chart: absolute logistic regression coefficients as a measure of feature influence
Performance chart: dual curves of validation-set and test-set accuracy versus feature count
t-SNE embedding: a 2-D projection of the best feature subset, color-coded by diagnosis
import plotly.io as pio
from IPython.display import HTML, display
import plotly.graph_objects as go
pio.renderers.default = "plotly_mimetype+notebook"
def create_interactive_boxplots(X):
base_colors = ['indianred', 'mediumseagreen', 'dodgerblue', 'plum', 'darkkhaki',
'lightsalmon', 'gold', 'mediumturquoise', 'darkorange', 'lightgreen']
all_colors = base_colors * (len(X.columns) // len(base_colors) + 1)
html_contents = []
for group_idx in range(0, len(X.columns), 5):
group_cols = X.columns[group_idx:group_idx + 5]
fig = go.Figure()
for col, color in zip(group_cols, all_colors[:len(group_cols)]):
fig.add_trace(go.Box(
y=X[col],
name=col,
boxpoints='outliers',
jitter=0.5,
pointpos=0,
whiskerwidth=0.2,
fillcolor=color,
marker=dict(size=3, color=color, line=dict(width=1, color='black')),
line=dict(width=2, color='black'),
opacity=0.8,
hovertemplate=f'Feature: {col}<br>Value: %{{y}}<extra></extra>'
))
fig.update_layout(
title=f'Feature distributions after outlier handling-é»äŗēæ - group {group_idx//5 + 1}',
height=500,
width=1000,
plot_bgcolor='white',
paper_bgcolor='white',
title_font_color='black',
font=dict(
family="Arial",
size=12,
color="black"
),
xaxis=dict(
tickfont=dict(color='black'),
title_font=dict(color='black')
),
yaxis=dict(
tickfont=dict(color='black'),
title_font=dict(color='black')
)
)
# Convert to HTML and store
html_contents.append(fig.to_html(
full_html=False,
include_plotlyjs='cdn',
config={'responsive': True}
))
return html_contents
def display_in_notebook(html_contents):
display(HTML("""
<style>
.plot-container {
margin: 20px auto;
width: 1000px;
box-shadow: 0 0 10px rgba(0,0,0,0.1);
padding: 15px;
}
.plot-title {
text-align: center;
font-size: 18px;
margin: 10px 0;
color: black; /* set text color */
}
</style>
"""))
for i, html in enumerate(html_contents):
display(HTML(f"""
<div class="plot-title">ē¹å¾ååøå¾ - 第 {i+1} ē»</div>
<div class="plot-container">{html}</div>
"""))
def save_full_html(html_contents, filename):
with open(filename, 'w', encoding='utf-8') as f:
f.write(f"""
<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<title>Breast cancer feature analysis</title>
<script src="https://cdn.plot.ly/plotly-latest.min.js"></script>
<style>
body {{
font-family: Arial;
margin: 20px;
color: black; /* set text color */
}}
.plot-container {{
margin: 0 auto;
width: 1000px;
}}
h2 {{
color: #2c3e50;
text-align: center;
}}
</style>
</head>
<body>
<h2>Breast cancer feature analysis report</h2>
{"".join([f'<div class="plot-container">{html}</div>' for html in html_contents])}
</body>
</html>
""")
if __name__ == "__main__":
boxplot_htmls = create_interactive_boxplots(X)
display_in_notebook(boxplot_htmls)
# Save as a complete HTML file
save_full_html(boxplot_htmls, "breast_cancer_analysis.html")
print("å·²äæåäøŗ breast_cancer_analysis.html")
Saved as breast_cancer_analysis.html
2.5 Model Building and Evaluation
Baseline model: logistic regression + SMOTE
from sklearn.metrics import confusion_matrix, classification_report
log_reg = LogisticRegression(max_iter=5000, random_state=42)
log_reg.fit(X_train, y_train)
y_pred_log_reg = log_reg.predict(X_test)
accuracy_log_reg = accuracy_score(y_test, y_pred_log_reg)
conf_matrix_log_reg = confusion_matrix(y_test, y_pred_log_reg)
class_report_log_reg = classification_report(y_test, y_pred_log_reg, output_dict=True)
print(f"é»č¾åå½ęØ”ååē”®ē: {accuracy_log_reg:.4f}")
print("ę··ę·ē©éµ:")
print(conf_matrix_log_reg)
print("åē±»ę„å:")
print(classification_report(y_test, y_pred_log_reg))
Logistic regression accuracy: 0.9737
Confusion matrix:
[[70 1]
[ 2 41]]
Classification report:
precision recall f1-score support
0 0.97 0.99 0.98 71
1 0.98 0.95 0.96 43
accuracy 0.97 114
macro avg 0.97 0.97 0.97 114
weighted avg 0.97 0.97 0.97 114
Optimized model: XGBoost + RandomizedSearchCV + early stopping + threshold optimization
from sklearn.model_selection import train_test_split, RandomizedSearchCV
from sklearn.metrics import f1_score, precision_score, recall_score, roc_curve, accuracy_score, confusion_matrix, classification_report
from xgboost import XGBClassifier
import numpy as np
from scipy.stats import uniform, randint
X_train = np.nan_to_num(X_train, nan=0.0)
y_train = np.nan_to_num(y_train, nan=0.0)
scale_pos_weight = len(y_train[y_train==0]) / len(y_train[y_train==1])
xgb = XGBClassifier(
eval_metric='logloss',
random_state=42,
scale_pos_weight=scale_pos_weight,
tree_method='hist'
)
param_dist = {
'n_estimators': randint(50, 300),
'max_depth': randint(2, 10),
'learning_rate': uniform(0.01, 0.3),
'subsample': uniform(0.6, 0.4),
'colsample_bytree': uniform(0.6, 0.4),
'gamma': uniform(0, 0.5),
'reg_alpha': uniform(0, 1),
'reg_lambda': uniform(0, 2),
'min_child_weight': randint(1, 10)
}
random_search = RandomizedSearchCV(
xgb,
param_dist,
n_iter=100,
cv=5,
scoring='f1_weighted',
n_jobs=1,
verbose=1,
random_state=42,
error_score='raise'
)
random_search.fit(X_train, y_train)
print(f"ęä½³åę°: {random_search.best_params_}")
X_train_sub, X_val, y_train_sub, y_val = train_test_split(
X_train, y_train, test_size=0.2, random_state=42, stratify=y_train
)
best_params = random_search.best_params_.copy()
original_n_estimators = best_params.pop('n_estimators', None)
xgb_early = XGBClassifier(
**best_params,
eval_metric='logloss',
random_state=42,
n_estimators=1000,
early_stopping_rounds=50,
tree_method='hist',
device='cpu'
)
xgb_early.fit(
X_train_sub,
y_train_sub,
eval_set=[(X_val, y_val)],
verbose=False
)
best_iter = xgb_early.best_iteration
final_xgb = XGBClassifier(
**best_params,
n_estimators=best_iter,
eval_metric='logloss',
random_state=42,
tree_method='hist',
device='cpu'
)
final_xgb.fit(X_train, y_train)
best_xgb = final_xgb
y_proba = best_xgb.predict_proba(X_test)[:, 1]
fpr, tpr, thresholds = roc_curve(y_test, y_proba)
gmeans = np.sqrt(tpr * (1 - fpr))  # geometric mean of sensitivity and specificity
ix = np.argmax(gmeans)
best_thresh = thresholds[ix]
# F-beta threshold optimization (weights recall more heavily than precision)
beta = 1.2  # beta > 1 favors recall
f_beta_scores = []
for thresh in thresholds:
y_pred_temp = (y_proba >= thresh).astype(int)
precision = precision_score(y_test, y_pred_temp, zero_division=0)
recall = recall_score(y_test, y_pred_temp)
if (precision + recall) > 0:
f_beta = (1 + beta**2) * (precision * recall) / ((beta**2 * precision) + recall)
else:
f_beta = 0
f_beta_scores.append(f_beta)
best_f_beta_ix = np.argmax(f_beta_scores)
best_f_beta_thresh = thresholds[best_f_beta_ix]
# Choose the final threshold (prefer the F-beta-optimized value)
y_pred_xgb = (y_proba >= best_f_beta_thresh).astype(int)
# Evaluation metrics (same as the original flow)
accuracy_xgb = accuracy_score(y_test, y_pred_xgb)
conf_matrix_xgb = confusion_matrix(y_test, y_pred_xgb)
class_report_xgb = classification_report(y_test, y_pred_xgb)
print(f"åå§ęä½³ę ę°é: {original_n_estimators}")
print(f"ę©åęä½³čæä»£ę¬”ę°: {best_iter}")
print(f"åŗäŗG-Meanēęä½³éå¼: {best_thresh:.4f}")
print(f"åŗäŗF{beta}åę°ēéå¼: {best_f_beta_thresh:.4f}")
print(f"ä¼ååēXGBoost樔ååē”®ē: {accuracy_xgb:.4f}")
print("ę··ę·ē©éµ:")
print(conf_matrix_xgb)
print("åē±»ę„å:")
print(class_report_xgb)
print("ę佳樔åå·²äæåäøŗbest_xgb")
Fitting 5 folds for each of 100 candidates, totalling 500 fits
Best parameters: {'colsample_bytree': np.float64(0.8346141738643036), 'gamma': np.float64(0.28222927126132064), 'learning_rate': np.float64(0.12363178778538679), 'max_depth': 6, 'min_child_weight': 6, 'n_estimators': 229, 'reg_alpha': np.float64(0.6459172413316012), 'reg_lambda': np.float64(1.1415566093378238), 'subsample': np.float64(0.7424386903591385)}
Original best n_estimators: 229
Best iteration with early stopping: 91
Best threshold by G-mean: 0.4970
Threshold by F1.2 score: 0.4970
Optimized XGBoost accuracy: 0.9825
Confusion matrix:
[[71 0]
[ 2 41]]
Classification report:
precision recall f1-score support
0 0.97 1.00 0.99 71
1 1.00 0.95 0.98 43
accuracy 0.98 114
macro avg 0.99 0.98 0.98 114
weighted avg 0.98 0.98 0.98 114
Best model saved as best_xgb
3. Key Conclusions
Compare the performance of logistic regression and XGBoost
Create a multi-metric bar chart comparing the models
Plot ROC curves to compare AUC values
Plot precision-recall curves
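The ROC and precision-recall comparisons can be sketched as follows. The helper name and the probas dict are illustrative; it assumes each model's positive-class probabilities are available (e.g. log_reg.predict_proba(X_test)[:, 1] and the y_proba computed for XGBoost):

```python
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, precision_recall_curve, auc, average_precision_score

def plot_roc_pr(y_true, probas):
    """Draw ROC and precision-recall curves for several models side by side.

    probas maps a model name to its predicted positive-class probabilities.
    """
    fig, (ax_roc, ax_pr) = plt.subplots(1, 2, figsize=(14, 6))
    for name, p in probas.items():
        fpr, tpr, _ = roc_curve(y_true, p)
        ax_roc.plot(fpr, tpr, label=f'{name} (AUC={auc(fpr, tpr):.3f})')
        prec, rec, _ = precision_recall_curve(y_true, p)
        ax_pr.plot(rec, prec, label=f'{name} (AP={average_precision_score(y_true, p):.3f})')
    ax_roc.plot([0, 1], [0, 1], 'k--', lw=1)  # chance diagonal
    ax_roc.set(xlabel='False positive rate', ylabel='True positive rate', title='ROC curves')
    ax_pr.set(xlabel='Recall', ylabel='Precision', title='Precision-recall curves')
    ax_roc.legend(loc='lower right')
    ax_pr.legend(loc='lower left')
    return fig
```

Average precision (AP) is used for the PR panel because trapezoidal area under an interpolated PR curve can be overly optimistic.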
# Model comparison
print("Model comparison:")
print(f"Logistic regression accuracy: {accuracy_log_reg:.4f}")
print(f"Optimized XGBoost accuracy: {accuracy_xgb:.4f}")
print("\nLogistic regression confusion matrix:")
print(conf_matrix_log_reg)
print("\nOptimized XGBoost confusion matrix:")
print(conf_matrix_xgb)
# Make sure both classification reports use the same format
class_report_log_reg = classification_report(
y_test, y_pred_log_reg,
output_dict=True,
target_names=['0', '1']
)
class_report_xgb = classification_report(
y_test, y_pred_xgb,
output_dict=True,
target_names=['0', '1']
)
print("\nLogistic regression classification report:")
print(classification_report(y_test, y_pred_log_reg))
print("\nTuned XGBoost classification report:")
print(classification_report(y_test, y_pred_xgb))
# Specificity = TN / (TN + FP)
def calculate_specificity(conf_matrix):
tn, fp, fn, tp = conf_matrix.ravel()
return tn / (tn + fp) if (tn + fp) > 0 else 0
specificity_log_reg = calculate_specificity(conf_matrix_log_reg)
specificity_xgb = calculate_specificity(conf_matrix_xgb)
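As a quick sanity check, applying the helper to the tuned XGBoost confusion matrix reported in this notebook (`[[71, 0], [2, 41]]`, i.e. zero false positives) should give a specificity of exactly 1.0. The function is restated here so the sketch runs on its own:

```python
import numpy as np

def calculate_specificity(conf_matrix):
    # sklearn's confusion_matrix layout for binary labels: [[TN, FP], [FN, TP]]
    tn, fp, fn, tp = conf_matrix.ravel()
    return tn / (tn + fp) if (tn + fp) > 0 else 0

# Tuned XGBoost confusion matrix from the output below: no false positives
cm = np.array([[71, 0], [2, 41]])
print(calculate_specificity(cm))  # 71 / (71 + 0) = 1.0
```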
# Logistic regression metrics (class 1)
precision_log_reg = class_report_log_reg['1']['precision']
recall_log_reg = class_report_log_reg['1']['recall']
f1_log_reg = class_report_log_reg['1']['f1-score']
# XGBoost metrics (class 1)
precision_xgb = class_report_xgb['1']['precision']
recall_xgb = class_report_xgb['1']['recall']
f1_xgb = class_report_xgb['1']['f1-score']
metrics = ['Accuracy', 'Recall', 'Specificity', 'Precision', 'F1']
log_reg_scores = [accuracy_log_reg, recall_log_reg, specificity_log_reg, precision_log_reg, f1_log_reg]
xgb_scores = [accuracy_xgb, recall_xgb, specificity_xgb, precision_xgb, f1_xgb]
x = np.arange(len(metrics))
width = 0.35
fig, ax = plt.subplots(figsize=(12, 7))
bars1 = ax.bar(x - width/2, log_reg_scores, width, label='Logistic regression', color='#1f77b4', edgecolor='black')
bars2 = ax.bar(x + width/2, xgb_scores, width, label='XGBoost', color='#ff7f0e', edgecolor='black')
ax.set_xlabel('Metric', fontsize=12)
ax.set_ylabel('Score', fontsize=12)
ax.set_title('Logistic Regression vs. XGBoost Performance - é»äŗēæ', fontsize=14, fontweight='bold', pad=20)
ax.set_xticks(x)
ax.set_xticklabels(metrics, fontsize=11)
ax.legend(fontsize=11)
ax.grid(True, axis='y', linestyle='--', alpha=0.7)
ax.set_ylim(0.85, 1.0)  # Tighten the y-axis range to highlight the differences
# Percentage labels on each bar, nudged slightly down along the y-axis
for bar in list(bars1) + list(bars2):
    yval = bar.get_height()
    ax.text(bar.get_x() + bar.get_width()/2, yval - 0.01,
            f'{yval:.1%}',
            ha='center', va='bottom', fontsize=9)
plt.figtext(0.5, 0.01,
            f"Logistic regression accuracy: {accuracy_log_reg:.1%} | XGBoost accuracy: {accuracy_xgb:.1%} | Gain: {(accuracy_xgb - accuracy_log_reg):.1%}",
            ha="center", fontsize=11, bbox=dict(facecolor='lightgray', alpha=0.5))
plt.tight_layout(rect=[0, 0.05, 1, 0.95])
plt.savefig('model_comparison.png', dpi=300, bbox_inches='tight')
plt.show()
# Fix the ROC-curve section: make sure the correct test set is used
from sklearn.metrics import roc_auc_score, roc_curve, precision_recall_curve
# Keep test-set preprocessing consistent with training
X_test_processed = np.nan_to_num(X_test, nan=0.0)
# Logistic regression predicted probabilities
y_pred_prob_log_reg = log_reg.predict_proba(X_test_processed)[:, 1]
# XGBoost predicted probabilities
y_pred_prob_xgb = best_xgb.predict_proba(X_test_processed)[:, 1]
# Compute AUC values
auc_log_reg = roc_auc_score(y_test, y_pred_prob_log_reg)
auc_xgb = roc_auc_score(y_test, y_pred_prob_xgb)
fpr_log_reg, tpr_log_reg, _ = roc_curve(y_test, y_pred_prob_log_reg)
fpr_xgb, tpr_xgb, _ = roc_curve(y_test, y_pred_prob_xgb)
plt.figure(figsize=(10, 6))
plt.plot(fpr_log_reg, tpr_log_reg, label=f'Logistic regression (AUC = {auc_log_reg:.3f})', linewidth=2)
plt.plot(fpr_xgb, tpr_xgb, label=f'XGBoost (AUC = {auc_xgb:.3f})', linewidth=2)
plt.plot([0, 1], [0, 1], 'k--', linewidth=1)
plt.xlabel('False positive rate', fontsize=12)
plt.ylabel('True positive rate', fontsize=12)
plt.title('ROC Curves - é»äŗēæ', fontsize=14, fontweight='bold', pad=20)
plt.legend(fontsize=11)
plt.grid(True, linestyle='--', alpha=0.7)
plt.savefig('roc_curve.png', dpi=300, bbox_inches='tight')
plt.show()
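For intuition: `roc_auc_score` is simply the trapezoidal area under the `(fpr, tpr)` points that `roc_curve` returns. A minimal check on toy labels and scores (hypothetical data, not the notebook's test split), using `sklearn.metrics.auc` for the trapezoidal integration:

```python
import numpy as np
from sklearn.metrics import roc_curve, roc_auc_score, auc

# Toy labels/scores: every positive is scored above every negative
y_true  = np.array([0, 0, 0, 0, 1, 1, 1, 1])
y_score = np.array([0.1, 0.3, 0.4, 0.55, 0.6, 0.7, 0.8, 0.9])

fpr, tpr, _ = roc_curve(y_true, y_score)
auc_from_points = auc(fpr, tpr)  # trapezoidal area under the (fpr, tpr) points
print(f"AUC = {auc_from_points:.3f}")  # perfectly separated scores -> 1.000
```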
# Plot precision-recall curves
# Use new names so the scalar precision/recall metrics above are not overwritten
pr_precision_log_reg, pr_recall_log_reg, _ = precision_recall_curve(y_test, y_pred_prob_log_reg)
pr_precision_xgb, pr_recall_xgb, _ = precision_recall_curve(y_test, y_pred_prob_xgb)
plt.figure(figsize=(10, 6))
plt.plot(pr_recall_log_reg, pr_precision_log_reg, label='Logistic regression', linewidth=2)
plt.plot(pr_recall_xgb, pr_precision_xgb, label='XGBoost', linewidth=2)
plt.xlabel('Recall', fontsize=12)
plt.ylabel('Precision', fontsize=12)
plt.title('Precision-Recall Curves - é»äŗēæ', fontsize=14, fontweight='bold', pad=20)
plt.legend(fontsize=11)
plt.grid(True, linestyle='--', alpha=0.7)
plt.savefig('pr_curve.png', dpi=300, bbox_inches='tight')
plt.show()
Model comparison:
Logistic regression accuracy: 0.9737
Tuned XGBoost accuracy: 0.9825

Logistic regression confusion matrix:
[[70  1]
 [ 2 41]]

Tuned XGBoost confusion matrix:
[[71  0]
 [ 2 41]]

Logistic regression classification report:
              precision    recall  f1-score   support

           0       0.97      0.99      0.98        71
           1       0.98      0.95      0.96        43

    accuracy                           0.97       114
   macro avg       0.97      0.97      0.97       114
weighted avg       0.97      0.97      0.97       114

Tuned XGBoost classification report:
              precision    recall  f1-score   support

           0       0.97      1.00      0.99        71
           1       1.00      0.95      0.98        43

    accuracy                           0.98       114
   macro avg       0.99      0.98      0.98       114
weighted avg       0.98      0.98      0.98       114
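The PR curves plotted above can be summarized in a single number with average precision (the area under the precision-recall curve), via `sklearn.metrics.average_precision_score`. A minimal sketch on toy data (hypothetical scores standing in for `y_test` / `y_pred_prob_xgb`):

```python
import numpy as np
from sklearn.metrics import average_precision_score

# Toy labels/scores: every positive is ranked above every negative
y_true  = np.array([0, 0, 0, 1, 1, 1])
y_score = np.array([0.2, 0.3, 0.4, 0.65, 0.7, 0.9])

ap = average_precision_score(y_true, y_score)
print(f"average precision = {ap:.3f}")  # perfect ranking -> 1.000
```

For the imbalanced screening setting here, average precision is often a more informative summary than ROC AUC, because it focuses on the positive (malignant) class.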